{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Coding Exercise: Model Adaption Meta Learning(MiniImageNet Dataset)\n",
    "\n",
    "In this tutorial, we will implement Model Adaption Meta Learning, to learn Omniglot classes with few examples.\n",
    "If you recall Model Agnostic Meta Learning, consists of 2 loops:\n",
    "1. To learn parameters for all tasks.\n",
    "2. To learn task specific parameters\n",
    "\n",
    "MAML algorithm aims to learn such parameters that can adapt to new tasks with very few examples. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"Images/parameters.png\" width=\"600\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "<img src=\"Images/maml_algo.png\" width=\"500\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install torchmeta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchmeta.modules import (MetaModule, MetaSequential, MetaConv2d, MetaConv3d,\n",
    "                               MetaBatchNorm2d, MetaLinear)\n",
    "from torchmeta.modules.utils import get_subdict\n",
    "from collections import OrderedDict\n",
    "import os\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "import torchmeta\n",
    "from torchmeta.utils.data import BatchMetaDataLoader \n",
    "from torchmeta.datasets import MiniImagenet\n",
    "from torchmeta.transforms import Categorical, ClassSplitter, Rotation\n",
    "from torchvision.transforms import Compose, Resize, ToTensor\n",
    "from torchmeta.utils.data import BatchMetaDataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Omniglot dataset\n",
    "\n",
    "Download Mini ImageNet Dataset: https://drive.google.com/file/d/1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk/view\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def conv3x3(in_channels, out_channels, **kwargs):\n",
    "    return MetaSequential(\n",
    "        MetaConv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),\n",
    "        MetaBatchNorm2d(out_channels, momentum=1., track_running_stats=False),\n",
    "        nn.ReLU(),\n",
    "        nn.MaxPool2d(2)\n",
    "    )\n",
    "\n",
    "class ConvolutionalNeuralNetwork(MetaModule):\n",
    "    def __init__(self, in_channels, out_features, hidden_size=64):\n",
    "        super(ConvolutionalNeuralNetwork, self).__init__()\n",
    "        self.in_channels = in_channels\n",
    "        self.out_features = out_features\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        self.features = MetaSequential(\n",
    "            conv3x3(in_channels, hidden_size),\n",
    "            conv3x3(hidden_size, hidden_size),\n",
    "            conv3x3(hidden_size, hidden_size),\n",
    "            # conv3x3(hidden_size, hidden_size),\n",
    "            # conv3x3(hidden_size, hidden_size),\n",
    "            conv3x3(hidden_size, hidden_size)\n",
    "        )\n",
    "\n",
    "        self.classifier = MetaLinear(hidden_size, out_features)\n",
    "\n",
    "    def forward(self, inputs, params=None):\n",
    "        features = self.features(inputs, params=get_subdict(params, 'features'))\n",
    "        features = features.view((features.size(0), -1))\n",
    "        logits = self.classifier(features, params=get_subdict(params, 'classifier'))\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def update_parameters(model, loss, step_size=0.5, first_order=False):\n",
    "    grads = torch.autograd.grad(loss, model.meta_parameters(),\n",
    "        create_graph=not first_order)\n",
    "\n",
    "    params = OrderedDict()\n",
    "    for (name, param), grad in zip(model.meta_named_parameters(), grads):\n",
    "        params[name] = param - step_size * grad\n",
    "\n",
    "    return params\n",
    "\n",
    "def get_accuracy(logits, targets):\n",
    "    _, predictions = torch.max(logits, dim=-1)\n",
    "    return torch.mean(predictions.eq(targets).float())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#  set, S, consisting of n examples each from k different unseen classes\n",
    "device = torch.device('cuda')\n",
    "num_ways = 5\n",
    "step_size = 0.4\n",
    "batch_size = 16\n",
    "num_batches = 2000\n",
    "first_order = True\n",
    "hidden_size = 64\n",
    "output_folder = './'\n",
    "num_classes_per_task = 5\n",
    "\n",
    "dataset = MiniImagenet(\"data\",\n",
    "                   # Number of ways\n",
    "                   num_classes_per_task=num_classes_per_task,\n",
    "                   # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n",
    "                   transform=Compose([Resize(28), ToTensor()]),\n",
    "                   # Transform the labels to integers\n",
    "                   target_transform=Categorical(num_classes=5),\n",
    "                   # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n",
    "                   class_augmentations=[Rotation([90, 180, 270])],\n",
    "                   meta_train=True,\n",
    "                   download=True)\n",
    "dataset = ClassSplitter(dataset, shuffle=True, num_train_per_class=5, num_test_per_class=15)\n",
    "dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model = ConvolutionalNeuralNetwork(3, num_ways, hidden_size=hidden_size)\n",
    "model.to(device=device)\n",
    "model.train()\n",
    "meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training loop\n",
    "with tqdm(dataloader, total=num_batches) as pbar:\n",
    "    for batch_idx, batch in enumerate(pbar):\n",
    "        model.zero_grad()\n",
    "\n",
    "        train_inputs, train_targets = batch['train']\n",
    "        train_inputs = train_inputs.to(device=device)\n",
    "        train_targets = train_targets.to(device=device)\n",
    "\n",
    "        test_inputs, test_targets = batch['test']\n",
    "        test_inputs = test_inputs.to(device=device)\n",
    "        test_targets = test_targets.to(device=device)\n",
    "\n",
    "        outer_loss = torch.tensor(0., device=device)\n",
    "        accuracy = torch.tensor(0., device=device)\n",
    "        \n",
    "        \n",
    "        for task_idx, (train_input, train_target, test_input,\n",
    "                test_target) in enumerate(zip(train_inputs, train_targets,\n",
    "                test_inputs, test_targets)):\n",
    "            train_logit = model(train_input)\n",
    "            inner_loss = F.cross_entropy(train_logit, train_target)\n",
    "\n",
    "            model.zero_grad()\n",
    "            params = update_parameters(model, inner_loss,\n",
    "                step_size=step_size, first_order=first_order)\n",
    "\n",
    "            test_logit = model(test_input, params=params)\n",
    "            outer_loss += F.cross_entropy(test_logit, test_target)\n",
    "\n",
    "            with torch.no_grad():\n",
    "                accuracy += get_accuracy(test_logit, test_target)\n",
    "\n",
    "        outer_loss.div_(batch_size)\n",
    "        accuracy.div_(batch_size)\n",
    "\n",
    "        outer_loss.backward()\n",
    "        meta_optimizer.step()\n",
    "\n",
    "        pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))\n",
    "        if batch_idx >= num_batches:\n",
    "            break\n",
    "\n",
    "# Save model\n",
    "if output_folder is not None:\n",
    "    filename = os.path.join(output_folder, 'maml_omniglot_'\n",
    "        '{0}shot_{1}way.pt'.format(num_classes_per_task, num_ways))\n",
    "    with open(filename, 'wb') as f:\n",
    "        state_dict = model.state_dict()\n",
    "        torch.save(state_dict, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
